{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Fine-Tuning a BERT-Base Hugging Face Model on the Kaggle Hate Speech Dataset\n",
    "\n",
    "# [Link to my YouTube Video Explaining this whole Notebook](https://www.youtube.com/watch?v=0Y03waAL4Gw&list=PLxqBkZuBynVTn2lkHNAcw6lgm1MD5QiMK&index=3)\n",
    "\n",
    "[![Imgur](https://imgur.com/dJt4A8B.png)](https://www.youtube.com/watch?v=0Y03waAL4Gw&list=PLxqBkZuBynVTn2lkHNAcw6lgm1MD5QiMK&index=3)\n",
    "\n",
    "---\n",
    "\n",
    "### [Dataset in Kaggle](https://www.kaggle.com/datasets/mrmorj/hate-speech-and-offensive-language-dataset?ref=hackernoon.com)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## First, What Is BERT?\n",
    "\n",
    "BERT stands for Bidirectional Encoder Representations from Transformers. The name itself gives us several clues to what BERT is all about.\n",
    "\n",
    "The BERT architecture consists of several Transformer encoders stacked together. Each Transformer encoder contains two sub-layers: a self-attention layer and a feed-forward layer.\n",
    "\n",
    "### There are two different BERT models:\n",
    "\n",
    "- BERT base, which consists of 12 Transformer encoder layers, 12 attention heads, a hidden size of 768, and about 110M parameters.\n",
    "\n",
    "- BERT large, which consists of 24 Transformer encoder layers, 16 attention heads, a hidden size of 1024, and about 340M parameters.\n",
    "\n",
    "\n",
    "\n",
    "### BERT Input and Output\n",
    "\n",
    "The BERT model expects a sequence of tokens (words or sub-word pieces) as input. In each sequence, BERT expects two special tokens:\n",
    "\n",
    "- [CLS]: This is the first token of every sequence; it stands for classification token.\n",
    "- [SEP]: This token tells BERT which tokens belong to which sequence. It matters mainly for tasks with two sequences, such as next sentence prediction or question answering. If we only have one sequence, this token is appended to the end of the sequence.\n",
    "\n",
    "\n",
    "It is also important to note that the maximum number of tokens that can be fed into the BERT model is 512. If a sequence has fewer than 512 tokens, we can use padding to fill the unused slots with the [PAD] token. If a sequence has more than 512 tokens, we need to truncate it.\n",
    "\n",
    "And that’s all that BERT expects as input.\n",
    "\n",
    "The BERT model then outputs an embedding vector for each token (of size 768 for BERT base). We can use these vectors as input for different kinds of NLP applications, such as text classification, next sentence prediction, Named-Entity Recognition (NER), or question answering.\n",
    "\n",
    "\n",
    "------------\n",
    "\n",
    "**For a text classification task**, we focus on the embedding vector output for the special [CLS] token. This means that we use the size-768 embedding vector of the [CLS] token as input to our classifier, which then outputs a vector whose size is the number of classes in our classification task.\n",
    "\n",
    "-----------------------\n",
    "\n",
    "![Imgur](https://imgur.com/NpeB9vb.png)\n",
    "\n",
    "-------------------------"
   ]
  },
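  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The classification step described above can be sketched with NumPy. This is a minimal illustration under stated assumptions, not the model's actual head: the `[CLS]` vector is filled with random stand-in values, and `cls_embedding`, `W`, `b`, and `classify` are hypothetical names that do not come from the model loaded later in this notebook.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "HIDDEN_SIZE = 768  # size of each BERT base output vector\n",
    "NUM_CLASSES = 3    # hate speech, offensive language, neither\n",
    "\n",
    "rng = np.random.default_rng(0)\n",
    "\n",
    "# Stand-in for the [CLS] output vector of one sequence (random values, not real BERT output).\n",
    "cls_embedding = rng.normal(size=(HIDDEN_SIZE,))\n",
    "\n",
    "# Hypothetical classifier parameters: a single linear layer on top of [CLS].\n",
    "W = rng.normal(scale=0.02, size=(HIDDEN_SIZE, NUM_CLASSES))\n",
    "b = np.zeros(NUM_CLASSES)\n",
    "\n",
    "def classify(cls_vec):\n",
    "    # Linear layer: produces one logit per class.\n",
    "    return cls_vec @ W + b\n",
    "\n",
    "logits = classify(cls_embedding)\n",
    "print(logits.shape)  # (3,)\n",
    "```\n",
    "\n",
    "The output has one entry per class, which is why the size of the classifier output equals the number of classes."
   ]
  },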
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# !pip install transformers\n",
    "# !pip install datasets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# !pip install --upgrade --force-reinstall --no-deps transformers\n",
    "# !pip install --upgrade --force-reinstall --no-deps datasets\n",
    "\n",
    "# !pip install --upgrade --force-reinstall --no-deps huggingface_hub\n",
    "\n",
    "# !pip install --upgrade --force-reinstall --no-deps pyarrow"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/usr/lib/python3/dist-packages/requests/__init__.py:89: RequestsDependencyWarning: urllib3 (1.26.6) or chardet (3.0.4) doesn't match a supported version!\n",
      "  warnings.warn(\"urllib3 ({}) or chardet ({}) doesn't match a supported \"\n"
     ]
    }
   ],
   "source": [
    "\n",
    "\n",
    "from datasets import load_dataset,DatasetDict\n",
    "from transformers import AutoTokenizer,TFAutoModelForSequenceClassification\n",
    "import tensorflow as tf\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import pandas as pd"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# DATA_PATH = \"labeled_data.csv\" # In Colab\n",
    "\n",
    "# DATA_PATH = '/kaggle/input/hate-speech-and-offensive-language-dataset/labeled_data.csv' # In Kaggle\n",
    "\n",
    "DATA_PATH = \"../input/labeled_data.csv\" # Local Machine "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Unnamed: 0</th>\n",
       "      <th>count</th>\n",
       "      <th>hate_speech</th>\n",
       "      <th>offensive_language</th>\n",
       "      <th>neither</th>\n",
       "      <th>class</th>\n",
       "      <th>tweet</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>0</td>\n",
       "      <td>3</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>3</td>\n",
       "      <td>2</td>\n",
       "      <td>!!! RT @mayasolovely: As a woman you shouldn't...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>1</td>\n",
       "      <td>3</td>\n",
       "      <td>0</td>\n",
       "      <td>3</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>!!!!! RT @mleew17: boy dats cold...tyga dwn ba...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>2</td>\n",
       "      <td>3</td>\n",
       "      <td>0</td>\n",
       "      <td>3</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>!!!!!!! RT @UrKindOfBrand Dawg!!!! RT @80sbaby...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>3</td>\n",
       "      <td>3</td>\n",
       "      <td>0</td>\n",
       "      <td>2</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>!!!!!!!!! RT @C_G_Anderson: @viva_based she lo...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>4</td>\n",
       "      <td>6</td>\n",
       "      <td>0</td>\n",
       "      <td>6</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>!!!!!!!!!!!!! RT @ShenikaRoberts: The shit you...</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   Unnamed: 0  count  hate_speech  offensive_language  neither  class  \\\n",
       "0           0      3            0                   0        3      2   \n",
       "1           1      3            0                   3        0      1   \n",
       "2           2      3            0                   3        0      1   \n",
       "3           3      3            0                   2        1      1   \n",
       "4           4      6            0                   6        0      1   \n",
       "\n",
       "                                               tweet  \n",
       "0  !!! RT @mayasolovely: As a woman you shouldn't...  \n",
       "1  !!!!! RT @mleew17: boy dats cold...tyga dwn ba...  \n",
       "2  !!!!!!! RT @UrKindOfBrand Dawg!!!! RT @80sbaby...  \n",
       "3  !!!!!!!!! RT @C_G_Anderson: @viva_based she lo...  \n",
       "4  !!!!!!!!!!!!! RT @ShenikaRoberts: The shit you...  "
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# pandas_df = pd.read_csv(\"labeled_data.csv\")\n",
    "pandas_df = pd.read_csv(DATA_PATH)\n",
    "pandas_df.head()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/paul/.local/lib/python3.9/site-packages/seaborn/_decorators.py:36: FutureWarning: Pass the following variable as a keyword arg: x. From version 0.12, the only valid positional argument will be `data`, and passing other arguments without an explicit keyword will result in an error or misinterpretation.\n",
      "  warnings.warn(\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<AxesSubplot:xlabel='class', ylabel='count'>"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAZEAAAEJCAYAAABVFBp5AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/MnkTPAAAACXBIWXMAAAsTAAALEwEAmpwYAAAWVklEQVR4nO3df7BfdX3n8efLoPaHOIDczcb8aNCNzgB1Y7mDTFkdV1cITNegU1kYK9EyREdodaa7K7rd4qLsuOuvEdalE5eUpGNBWlSyM3FpmlXpOiJJkAIBKVeEJZmQxMQaXS1t8L1/fD9XvoabcDnc7/fLzX0+Zs58z/d9fn3O3Elecz6fc843VYUkSV08b9QNkCTNXoaIJKkzQ0SS1JkhIknqzBCRJHVmiEiSOhtYiCRZnOSrSe5Lsj3J+1r9hCSbkjzYPo9v9SS5OslEkruT/Ebfvla19R9MsqqvflqSe9o2VyfJoM5HkvRUg7wSOQj8QVWdDJwBXJrkZOByYHNVLQM2t+8A5wDL2rQauBZ6oQNcAbwGOB24YjJ42jqX9G23YoDnI0k6xDGD2nFV7QJ2tfkfJbkfWAisBF7fVlsHfA34QKuvr97Tj7cnOS7JgrbupqraD5BkE7AiydeAF1fV7a2+HjgP+MqR2nXiiSfW0qVLZ+o0JWlO2LZt2/erauzQ+sBCpF+SpcCrgW8B81vAADwGzG/zC4FH+zbb0WpHqu+Yon5ES5cuZevWrc/8JCRpDkvyyFT1gQ+sJ3kRcDPw/qo60L+sXXUM/L0rSVYn2Zpk6969ewd9OEmaMwYaIkmeTy9APl9VX2zl3a2biva5p9V3Aov7Nl/UakeqL5qi/hRVtaaqxqtqfGzsKVdjkqSOBnl3VoDrgPur6lN9izYAk3dYrQJu6atf1O7SOgP4Yev2uhU4K8nxbUD9LODWtuxAkjPasS7q25ckaQgGOSZyJvAO4J4kd7Xah4CPATcluRh4BDi/LdsInAtMAD8B3gVQVfuTfATY0ta7cnKQHXgvcD3wy/QG1I84qC5JmlmZa6+CHx8fLwfWJemZSbKtqsYPrfvEuiSpM0NEktSZISJJ6swQkSR1NpQn1qVR+L9X/vqom3DUW/JH94y6CRoxr0QkSZ0ZIpKkzgwRSVJnhogkqTNDRJLUmSEiSerMEJEkdWaISJI6M0QkSZ0ZIpKkzgwRSVJnhogkqTNDRJLU2cBCJMnaJHuS3NtX+0KSu9r08ORvrydZmuSnfcv+uG+b05Lck2QiydVJ0uonJNmU5MH2efygzkWSNLVBXolcD6zoL1TVv6mq5VW1HLgZ+GLf4u9OLquq9/TVrwUuAZa1aXKflwObq2oZsLl9lyQN0cBCpKpuA/ZPtaxdTZwP3HCkfSRZALy4qm6vqgLWA+e1xSuBdW1+XV9dkjQkoxoTeS2wu6oe7KudlOTbSb6e5LWtthDY0bfOjlYDmF9Vu9r8Y8D8gbZYkvQUo/plwwv5xauQXcCSqtqX5DTgy0lOme7OqqqS1OGWJ1kNrAZYsmRJxyZLkg419CuRJMcAbwW+MFmrqseral+b3wZ8F3gFsBNY1Lf5olYD2N26uya7vfYc7phVtaaqxqtqfGxsbCZPR5LmtFF0Z/0r4DtV9fNuqiRjSea1+ZfRG0B/qHVXHUhyRhtHuQi4pW22AVjV5lf11SVJQzLIW3xvAL4JvDLJjiQXt0UX8NQB9dcBd7dbfv8CeE9VTQ7Kvxf4H8AEvSuUr7T6x4A3JXmQXjB9bFDnIkma2sDGRKrqwsPU3zlF7WZ6t/xOtf5W4NQp6vuANz67VkqSng2fWJckdWaISJI6M0QkSZ0ZIpKkzgwRSVJnhogkqTNDRJLUmSEiSerMEJEkdWaISJI6M0QkSZ0ZIpKkzgwRSVJnhogkqTNDRJLUmSEiSerMEJEkdWaISJI6G+RvrK9NsifJvX21DyfZmeSuNp3bt+yDSSaSPJDk7L76ilab
SHJ5X/2kJN9q9S8kecGgzkWSNLVBXolcD6yYov7pqlrepo0ASU4GLgBOadv89yTzkswDPgucA5wMXNjWBfgvbV//DPgBcPEAz0WSNIWBhUhV3Qbsn+bqK4Ebq+rxqvoeMAGc3qaJqnqoqv4BuBFYmSTAG4C/aNuvA86byfZLkp7eKMZELktyd+vuOr7VFgKP9q2zo9UOV38J8HdVdfCQuiRpiIYdItcCLweWA7uATw7joElWJ9maZOvevXuHcUhJmhOGGiJVtbuqnqiqnwGfo9ddBbATWNy36qJWO1x9H3BckmMOqR/uuGuqaryqxsfGxmbmZCRJww2RJAv6vr4FmLxzawNwQZIXJjkJWAbcAWwBlrU7sV5Ab/B9Q1UV8FXgt9v2q4BbhnEOkqQnHfP0q3ST5Abg9cCJSXYAVwCvT7IcKOBh4N0AVbU9yU3AfcBB4NKqeqLt5zLgVmAesLaqtrdDfAC4MclHgW8D1w3qXCRJUxtYiFTVhVOUD/sffVVdBVw1RX0jsHGK+kM82R0mSRoBn1iXJHVmiEiSOjNEJEmdGSKSpM4MEUlSZ4aIJKkzQ0SS1JkhIknqzBCRJHVmiEiSOjNEJEmdGSKSpM4MEUlSZ4aIJKkzQ0SS1JkhIknqzBCRJHVmiEiSOhtYiCRZm2RPknv7ah9P8p0kdyf5UpLjWn1pkp8muatNf9y3zWlJ7kkykeTqJGn1E5JsSvJg+zx+UOciSZraIK9ErgdWHFLbBJxaVa8C/hb4YN+y71bV8ja9p69+LXAJsKxNk/u8HNhcVcuAze27JGmIBhYiVXUbsP+Q2l9W1cH29XZg0ZH2kWQB8OKqur2qClgPnNcWrwTWtfl1fXVJ0pCMckzkd4Gv9H0/Kcm3k3w9yWtbbSGwo2+dHa0GML+qdrX5x4D5A22tJOkpjhnFQZP8B+Ag8PlW2gUsqap9SU4DvpzklOnur6oqSR3heKuB1QBLlizp3nBJ0i8Y+pVIkncCvwW8vXVRUVWPV9W+Nr8N+C7wCmAnv9jltajVAHa37q7Jbq89hztmVa2pqvGqGh8bG5vhM5KkuWuoIZJkBfDvgTdX1U/66mNJ5rX5l9EbQH+odVcdSHJGuyvrIuCWttkGYFWbX9VXlyQNycC6s5LcALweODHJDuAKendjvRDY1O7Uvb3difU64Mok/wj8DHhPVU0Oyr+X3p1ev0xvDGVyHOVjwE1JLgYeAc4f1LlIkqY2sBCpqgunKF93mHVvBm4+zLKtwKlT1PcBb3w2bZQkPTs+sS5J6swQkSR1ZohIkjozRCRJnRkikqTODBFJUmeGiCSpM0NEktSZISJJ6mxaIZJk83RqkqS55YivPUnyS8Cv0Hv/1fFA2qIX8+TvekiS5qine3fWu4H3Ay8FtvFkiBwA/tvgmiVJmg2OGCJV9RngM0l+r6quGVKbJEmzxLTe4ltV1yT5TWBp/zZVtX5A7ZIkzQLTCpEkfwq8HLgLeKKVCzBEJGkOm+7viYwDJ0/+nK0kSTD950TuBf7pIBsiSZp9pnslciJwX5I7gMcni1X15oG0SpI0K0w3RD7cZedJ1gK/BeypqlNb7QTgC/QG6R8Gzq+qH6T3o+ufAc4FfgK8s6rubNusAv6w7fajVbWu1U/jyd9f3wi8zy43SRqeaXVnVdXXp5qmsen1wIpDapcDm6tqGbC5fQc4B1jWptXAtfDz0LkCeA1wOnBFe/CRts4lfdsdeixJ0gBN97UnP0pyoE1/n+SJJAeebruqug3Yf0h5JbCuza8Dzuurr6+e24HjkiwAzgY2VdX+qvoBsAlY0Za9uKpub1cf6/v2JUkaguk+J3Ls5HzrdloJnNHxmPOralebfwyY3+YXAo/2rbej1Y5U3zFFXZI0JM/4Lb7tSuHL9K4QnpV2BTHwMYwkq5NsTbJ17969gz6cJM0Z033Y8K19X59H77mRv+94zN1JFlTVrtYltafVdwKL+9Zb1Go7gdcfUv9aqy+aYv2nqKo1
wBqA8fFxB94laYZM90rkX/dNZwM/otel1cUGYFWbXwXc0le/KD1nAD9s3V63AmclOb4NqJ8F3NqWHUhyRutiu6hvX5KkIZjumMi7uuw8yQ30riJOTLKD3l1WHwNuSnIx8Ahwflt9I73beyfo3eL7rnbs/Uk+Amxp611ZVZOD9e/lyVt8v9ImSdKQTLc7axFwDXBmK/01vWcydhx+K6iqCw+z6I1TrFvApYfZz1pg7RT1rcCpR2qDJGlwptud9Sf0upte2qb/2WqSpDlsuiEyVlV/UlUH23Q9MDbAdkmSZoHphsi+JL+TZF6bfgfYN8iGSZKe+6YbIr9LbwD8MWAX8NvAOwfUJknSLDHdFzBeCaxqrx2ZfJ/VJ+iFiyRpjprulcirJgMEerfdAq8eTJMkSbPFdEPkeX1vzp28EpnuVYwk6Sg13SD4JPDNJH/evr8NuGowTZIkzRbTfWJ9fZKtwBta6a1Vdd/gmiVJmg2m3SXVQsPgkCT93DN+FbwkSZMMEUlSZ4aIJKkzQ0SS1JkhIknqzBCRJHVmiEiSOjNEJEmdDT1EkrwyyV1904Ek70/y4SQ7++rn9m3zwSQTSR5IcnZffUWrTSS5fNjnIklz3dBfolhVDwDLAZLMA3YCXwLeBXy6qj7Rv36Sk4ELgFPo/TTvXyV5RVv8WeBNwA5gS5INvo5FkoZn1G/ifSPw3ap6JMnh1lkJ3FhVjwPfSzIBnN6WTVTVQwBJbmzrGiKSNCSjHhO5ALih7/tlSe5Osrbv1fMLgUf71tnRaoerS5KGZGQhkuQFwJuBydfLXwu8nF5X1y56r5+fqWOtTrI1yda9e/fO1G4lac4b5ZXIOcCdVbUboKp2V9UTVfUz4HM82WW1E1jct92iVjtc/Smqak1VjVfV+NjY2AyfhiTNXaMMkQvp68pKsqBv2VuAe9v8BuCCJC9MchKwDLgD2AIsS3JSu6q5oK0rSRqSkQysJ/lVendVvbuv/F+TLAcKeHhyWVVtT3ITvQHzg8ClVfVE289lwK3APGBtVW0f1jlIkkYUIlX1/4CXHFJ7xxHWv4opfo63qjYCG2e8gZKkaRn13VmSpFnMEJEkdWaISJI6M0QkSZ0ZIpKkzgwRSVJnhogkqTNDRJLUmSEiSerMEJEkdWaISJI6M0QkSZ0ZIpKkzgwRSVJnhogkqTNDRJLUmSEiSerMEJEkdTayEEnycJJ7ktyVZGurnZBkU5IH2+fxrZ4kVyeZSHJ3kt/o28+qtv6DSVaN6nwkaS4a9ZXIv6yq5VU13r5fDmyuqmXA5vYd4BxgWZtWA9dCL3SAK4DXAKcDV0wGjyRp8EYdIodaCaxr8+uA8/rq66vnduC4JAuAs4FNVbW/qn4AbAJWDLnNkjRnjTJECvjLJNuSrG61+VW1q80/Bsxv8wuBR/u23dFqh6tLkobgmBEe+19U1c4k/wTYlOQ7/QurqpLUTByohdRqgCVLlszELiVJjPBKpKp2ts89wJfojWnsbt1UtM89bfWdwOK+zRe12uHqhx5rTVWNV9X42NjYTJ+KJM1ZIwmRJL+a5NjJeeAs4F5gAzB5h9Uq4JY2vwG4qN2ldQbww9btdStwVpLj24D6Wa0mSRqCUXVnzQe+lGSyDX9WVf8ryRbgpiQXA48A57f1NwLnAhPAT4B3AVTV/iQfAba09a6sqv3DOw1JmttGEiJV9RDwz6eo7wPeOEW9gEsPs6+1wNqZbqMk6ek9127xlSTNIoaIJKkzQ0SS1JkhIknqbJQPG0rSlM685sxRN+Go943f+8aM7McrEUlSZ4aIJKkzQ0SS1JkhIknqzBCRJHVmiEiSOjNEJEmdGSKSpM4MEUlSZ4aIJKkzQ0SS1JkhIknqzBCRJHU29BBJsjjJV5Pcl2R7kve1+oeT7ExyV5vO7dvmg0kmkjyQ5Oy++opWm0hy+bDPRZLmulG8Cv4g8AdVdWeSY4FtSTa1ZZ+uqk/0r5zkZOAC4BTgpcBfJXlFW/xZ4E3A
DmBLkg1Vdd9QzkKSNPwQqapdwK42/6Mk9wMLj7DJSuDGqnoc+F6SCeD0tmyiqh4CSHJjW9cQkaQhGemYSJKlwKuBb7XSZUnuTrI2yfGtthB4tG+zHa12uLokaUhGFiJJXgTcDLy/qg4A1wIvB5bTu1L55Awea3WSrUm27t27d6Z2K0lz3khCJMnz6QXI56vqiwBVtbuqnqiqnwGf48kuq53A4r7NF7Xa4epPUVVrqmq8qsbHxsZm9mQkaQ4bxd1ZAa4D7q+qT/XVF/St9hbg3ja/AbggyQuTnAQsA+4AtgDLkpyU5AX0Bt83DOMcJEk9o7g760zgHcA9Se5qtQ8BFyZZDhTwMPBugKranuQmegPmB4FLq+oJgCSXAbcC84C1VbV9eKchSRrF3Vn/B8gUizYeYZurgKumqG880naSpMHyiXVJUmeGiCSpM0NEktSZISJJ6swQkSR1ZohIkjobxXMis8Zp/279qJtw1Nv28YtG3QRJz4JXIpKkzgwRSVJnhogkqTNDRJLUmSEiSerMEJEkdWaISJI6M0QkSZ0ZIpKkzgwRSVJnhogkqbNZHyJJViR5IMlEkstH3R5JmktmdYgkmQd8FjgHOBm4MMnJo22VJM0dszpEgNOBiap6qKr+AbgRWDniNknSnDHbQ2Qh8Gjf9x2tJkkagjnxeyJJVgOr29cfJ3lglO0ZsBOB74+6EdOVT6wadROeS2bV3w6AKzLqFjyXzKq/X37/Gf/tfm2q4mwPkZ3A4r7vi1rtF1TVGmDNsBo1Skm2VtX4qNuhZ86/3ew2V/9+s707awuwLMlJSV4AXABsGHGbJGnOmNVXIlV1MMllwK3APGBtVW0fcbMkac6Y1SECUFUbgY2jbsdzyJzotjtK+beb3ebk3y9VNeo2SJJmqdk+JiJJGiFD5Cjh619mryRrk+xJcu+o26JnJsniJF9Ncl+S7UneN+o2DZvdWUeB9vqXvwXeRO+Byy3AhVV130gbpmlJ8jrgx8D6qjp11O3R9CVZACyoqjuTHAtsA86bS//2vBI5Ovj6l1msqm4D9o+6HXrmqmpXVd3Z5n8E3M8ce2uGIXJ08PUv0oglWQq8GvjWiJsyVIaIJD1LSV4E3Ay8v6oOjLo9w2SIHB2m9foXSTMvyfPpBcjnq+qLo27PsBkiRwdf/yKNQJIA1wH3V9WnRt2eUTBEjgJVdRCYfP3L/cBNvv5l9khyA/BN4JVJdiS5eNRt0rSdCbwDeEOSu9p07qgbNUze4itJ6swrEUlSZ4aIJKkzQ0SS1JkhIknqzBCRJHVmiEhDlOTDSf7tqNshzRRDRJLUmSEiDVCSi5LcneRvkvzpIcsuSbKlLbs5ya+0+tuS3Nvqt7XaKUnuaA+z3Z1k2SjORzqUDxtKA5LkFOBLwG9W1feTnAD8PvDjqvpEkpdU1b627keB3VV1TZJ7gBVVtTPJcVX1d0muAW6vqs+3V9vMq6qfjurcpEleiUiD8wbgz6vq+wBVdehvhpya5K9baLwdOKXVvwFcn+QSYF6rfRP4UJIPAL9mgOi5whCRRud64LKq+nXgPwG/BFBV7wH+kN6bmbe1K5Y/A94M/BTYmOQNo2my9IsMEWlw/jfwtiQvAWjdWf2OBXa1V4m/fbKY5OVV9a2q+iNgL7A4ycuAh6rqauAW4FVDOQPpaRwz6gZIR6uq2p7kKuDrSZ4Avg083LfKf6T3K3h72+exrf7xNnAeYDPwN8AHgHck+UfgMeA/D+UkpKfhwLokqTO7syRJnRkikqTODBFJUmeGiCSpM0NEktSZISJJ6swQkSR1ZohIkjr7/wSssrpswJlgAAAAAElFTkSuQmCC",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
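   ],
   "source": [
    "import seaborn as sns\n",
    "\n",
    "# Pass the column as a keyword argument; from seaborn 0.12 only `data` may be positional.\n",
    "sns.countplot(x='class', data=pandas_df)\n"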
   ]
  },
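  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "If the count plot shows that one class dominates, accuracy alone can be misleading, and one common option is to weight classes inversely to their frequency. The sketch below is a minimal illustration that uses a small made-up label list rather than the real dataset; `toy_labels` and `inverse_frequency_weights` are hypothetical names. Keras `fit` has a `class_weight` argument that takes a dictionary of this form.\n",
    "\n",
    "```python\n",
    "from collections import Counter\n",
    "\n",
    "def inverse_frequency_weights(labels):\n",
    "    # weight_c = n_samples / (n_classes * count_c), so rarer classes get larger weights.\n",
    "    counts = Counter(labels)\n",
    "    n_samples = len(labels)\n",
    "    n_classes = len(counts)\n",
    "    return {c: n_samples / (n_classes * n) for c, n in sorted(counts.items())}\n",
    "\n",
    "toy_labels = [1, 1, 1, 1, 1, 1, 2, 2, 2, 0]  # made-up labels, not the real class counts\n",
    "weights = inverse_frequency_weights(toy_labels)\n",
    "print(weights)\n",
    "```\n",
    "\n",
    "With the real data, the labels would come from `pandas_df['class']`; whether weighting helps here is something to check on the validation set."
   ]
  },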
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Unnamed: 0</th>\n",
       "      <th>count</th>\n",
       "      <th>hate_speech</th>\n",
       "      <th>offensive_language</th>\n",
       "      <th>neither</th>\n",
       "      <th>class</th>\n",
       "      <th>tweet</th>\n",
       "      <th>tweet_cleaned</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>0</td>\n",
       "      <td>3</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>3</td>\n",
       "      <td>2</td>\n",
       "      <td>!!! RT @mayasolovely: As a woman you shouldn't...</td>\n",
       "      <td>!!! RT : As a woman you shouldn't complain abo...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>1</td>\n",
       "      <td>3</td>\n",
       "      <td>0</td>\n",
       "      <td>3</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>!!!!! RT @mleew17: boy dats cold...tyga dwn ba...</td>\n",
       "      <td>!!!!! RT : boy dats cold...tyga dwn bad for cu...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>2</td>\n",
       "      <td>3</td>\n",
       "      <td>0</td>\n",
       "      <td>3</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>!!!!!!! RT @UrKindOfBrand Dawg!!!! RT @80sbaby...</td>\n",
       "      <td>!!!!!!! RT Dawg!!!! RT : You ever fuck a bitch...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>3</td>\n",
       "      <td>3</td>\n",
       "      <td>0</td>\n",
       "      <td>2</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>!!!!!!!!! RT @C_G_Anderson: @viva_based she lo...</td>\n",
       "      <td>!!!!!!!!! RT _G_Anderson: _based she look like...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>4</td>\n",
       "      <td>6</td>\n",
       "      <td>0</td>\n",
       "      <td>6</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>!!!!!!!!!!!!! RT @ShenikaRoberts: The shit you...</td>\n",
       "      <td>!!!!!!!!!!!!! RT : The shit you hear about me ...</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   Unnamed: 0  count  hate_speech  offensive_language  neither  class  \\\n",
       "0           0      3            0                   0        3      2   \n",
       "1           1      3            0                   3        0      1   \n",
       "2           2      3            0                   3        0      1   \n",
       "3           3      3            0                   2        1      1   \n",
       "4           4      6            0                   6        0      1   \n",
       "\n",
       "                                               tweet  \\\n",
       "0  !!! RT @mayasolovely: As a woman you shouldn't...   \n",
       "1  !!!!! RT @mleew17: boy dats cold...tyga dwn ba...   \n",
       "2  !!!!!!! RT @UrKindOfBrand Dawg!!!! RT @80sbaby...   \n",
       "3  !!!!!!!!! RT @C_G_Anderson: @viva_based she lo...   \n",
       "4  !!!!!!!!!!!!! RT @ShenikaRoberts: The shit you...   \n",
       "\n",
       "                                       tweet_cleaned  \n",
       "0  !!! RT : As a woman you shouldn't complain abo...  \n",
       "1  !!!!! RT : boy dats cold...tyga dwn bad for cu...  \n",
       "2  !!!!!!! RT Dawg!!!! RT : You ever fuck a bitch...  \n",
       "3  !!!!!!!!! RT _G_Anderson: _based she look like...  \n",
       "4  !!!!!!!!!!!!! RT : The shit you hear about me ...  "
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Strip @mentions, including handles that contain underscores (e.g. @C_G_Anderson).\n",
    "pandas_df['tweet_cleaned'] = pandas_df['tweet'].str.replace(r'@[A-Za-z0-9_]+\\s?', '', regex=True)\n",
    "pandas_df.head()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load the pandas DataFrame into a Hugging Face Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Dataset({\n",
       "    features: ['Unnamed: 0', 'count', 'hate_speech', 'offensive_language', 'neither', 'class', 'tweet', 'tweet_cleaned'],\n",
       "    num_rows: 24783\n",
       "})"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from datasets import Dataset\n",
    "\n",
    "ds = Dataset.from_pandas(pandas_df)\n",
    "ds"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using custom data configuration default-f1b446edcc2fbcfb\n",
      "Reusing dataset csv (/home/paul/.cache/huggingface/datasets/csv/default-f1b446edcc2fbcfb/0.0.0/433e0ccc46f9880962cc2b12065189766fbb2bee57a221866138fb9203c83519)\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "Dataset({\n",
       "    features: ['Unnamed: 0', 'count', 'hate_speech', 'offensive_language', 'neither', 'class', 'tweet'],\n",
       "    num_rows: 24783\n",
       "})"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
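   ],
   "source": [
    "# Alternative: load the CSV directly with load_dataset (this version has no 'tweet_cleaned' column).\n",
    "# Note that `dataset` is reassigned in the next cell.\n",
    "dataset = load_dataset('csv', data_files=DATA_PATH, split='train')\n",
    "\n",
    "dataset"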
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "DatasetDict({\n",
       "    train: Dataset({\n",
       "        features: ['class', 'tweet', 'tweet_cleaned'],\n",
       "        num_rows: 18587\n",
       "    })\n",
       "    test: Dataset({\n",
       "        features: ['class', 'tweet', 'tweet_cleaned'],\n",
       "        num_rows: 1549\n",
       "    })\n",
       "    valid: Dataset({\n",
       "        features: ['class', 'tweet', 'tweet_cleaned'],\n",
       "        num_rows: 4647\n",
       "    })\n",
       "})"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# First split: 75% train / 25% held out (the default test_size is 0.25).\n",
    "train_test_valid = ds.train_test_split()\n",
    "\n",
    "# Second split: divide the held-out part into validation (75% of it) and test (25% of it).\n",
    "test_valid = train_test_valid['test'].train_test_split()\n",
    "\n",
    "train_test_valid_dataset = DatasetDict({\n",
    "    'train': train_test_valid['train'],\n",
    "    'test': test_valid['test'],\n",
    "    'valid': test_valid['train']\n",
    "    })\n",
    "\n",
    "# Keep only the label and the text columns.\n",
    "dataset = train_test_valid_dataset.remove_columns(['hate_speech', 'offensive_language', 'neither','Unnamed: 0', 'count'])\n",
    "dataset\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Tokenizer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "tokenizer = AutoTokenizer.from_pretrained(\"bert-base-cased\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'input_ids': [101, 2066, 9444, 22559, 2734, 102], 'token_type_ids': [0, 0, 0, 0, 0, 0], 'attention_mask': [1, 1, 1, 1, 1, 1]}"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "text = \"Just checking tokenization\"\n",
    "\n",
    "output = tokenizer(text)\n",
    "\n",
    "output"
   ]
  },
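  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The encoding above starts with id 101 and ends with id 102, which the next cell shows to be `[CLS]` and `[SEP]`. The sketch below illustrates how special tokens, truncation, padding, and the attention mask fit together. It is a pure-Python illustration, assuming `[PAD]` has id 0 as in the standard BERT vocabularies; `build_inputs` is a hypothetical helper and not the tokenizer's implementation.\n",
    "\n",
    "```python\n",
    "CLS_ID, SEP_ID, PAD_ID = 101, 102, 0\n",
    "\n",
    "def build_inputs(token_ids, max_length):\n",
    "    # Reserve two positions for [CLS] and [SEP], truncating the rest if needed.\n",
    "    body = token_ids[: max_length - 2]\n",
    "    input_ids = [CLS_ID] + body + [SEP_ID]\n",
    "    attention_mask = [1] * len(input_ids)\n",
    "    # Pad up to max_length; padded positions get attention_mask 0.\n",
    "    pad = max_length - len(input_ids)\n",
    "    return {\n",
    "        'input_ids': input_ids + [PAD_ID] * pad,\n",
    "        'attention_mask': attention_mask + [0] * pad,\n",
    "    }\n",
    "\n",
    "encoded = build_inputs([2066, 9444, 22559, 2734], max_length=8)\n",
    "print(encoded['input_ids'])       # [101, 2066, 9444, 22559, 2734, 102, 0, 0]\n",
    "print(encoded['attention_mask'])  # [1, 1, 1, 1, 1, 1, 0, 0]\n",
    "```\n",
    "\n",
    "The attention mask is what lets the model ignore the `[PAD]` positions."
   ]
  },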
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['[CLS]', 'Just', 'checking', 'token', '##ization', '[SEP]']"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tokens = tokenizer.convert_ids_to_tokens(output['input_ids'])\n",
    "tokens"
   ]
  },
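  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `##ization` piece above comes from WordPiece tokenization: a word that is not in the vocabulary is split into known sub-word pieces, with `##` marking a continuation. The sketch below shows the greedy longest-match-first idea on a tiny made-up vocabulary; `TOY_VOCAB` and `wordpiece` are hypothetical and much simpler than the real tokenizer.\n",
    "\n",
    "```python\n",
    "TOY_VOCAB = {'Just', 'checking', 'token', '##ization', '[UNK]'}\n",
    "\n",
    "def wordpiece(word, vocab):\n",
    "    pieces, start = [], 0\n",
    "    while start < len(word):\n",
    "        end = len(word)\n",
    "        match = None\n",
    "        # Try the longest remaining substring first, then shrink it.\n",
    "        while start < end:\n",
    "            candidate = word[start:end]\n",
    "            if start > 0:\n",
    "                candidate = '##' + candidate  # continuation pieces carry the ## prefix\n",
    "            if candidate in vocab:\n",
    "                match = candidate\n",
    "                break\n",
    "            end -= 1\n",
    "        if match is None:\n",
    "            return ['[UNK]']  # no piece fits, so the whole word is unknown\n",
    "        pieces.append(match)\n",
    "        start = end\n",
    "    return pieces\n",
    "\n",
    "pieces = [p for w in 'Just checking tokenization'.split() for p in wordpiece(w, TOY_VOCAB)]\n",
    "print(pieces)  # ['Just', 'checking', 'token', '##ization']\n",
    "```"
   ]
  },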
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Tokenized text: [CLS] Just checking tokenization [SEP]\n"
     ]
    }
   ],
   "source": [
    "print(f\"Tokenized text: {tokenizer.convert_tokens_to_string(tokens)}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Vocab size is : 28996\n",
      "Model max length is : 512\n",
      "Model input names are: ['input_ids', 'token_type_ids', 'attention_mask']\n"
     ]
    }
   ],
   "source": [
    "print(f\"Vocab size is : {tokenizer.vocab_size}\")\n",
    "\n",
    "print(f\"Model max length is : {tokenizer.model_max_length}\")\n",
    "\n",
    "print(f\"Model input names are: {tokenizer.model_input_names}\")"
   ]
  },
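  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The vocabulary size and maximum length printed above are two of the quantities that determine the model's size. As a rough check of the parameter figures quoted earlier, the sketch below counts weights and biases from the architecture numbers. It assumes the standard BERT layout (word, position, and segment embeddings, encoder layers with a feed-forward size of four times the hidden size, and a pooler layer); `count_bert_params` is a hypothetical helper, and nothing here is read from the loaded model.\n",
    "\n",
    "```python\n",
    "def count_bert_params(vocab_size, hidden, layers, max_len=512, type_vocab=2):\n",
    "    ffn = 4 * hidden         # assumed feed-forward (intermediate) size\n",
    "    layer_norm = 2 * hidden  # scale and shift vectors\n",
    "    embeddings = (vocab_size + max_len + type_vocab) * hidden + layer_norm\n",
    "    # Query, key, value, and output projections, plus the layer norm after attention.\n",
    "    attention = 4 * (hidden * hidden + hidden) + layer_norm\n",
    "    # Two dense layers (hidden -> ffn -> hidden), plus the layer norm after them.\n",
    "    feed_forward = (hidden * ffn + ffn) + (ffn * hidden + hidden) + layer_norm\n",
    "    pooler = hidden * hidden + hidden\n",
    "    return embeddings + layers * (attention + feed_forward) + pooler\n",
    "\n",
    "base_cased = count_bert_params(28996, 768, 12)  # cased vocabulary size printed above\n",
    "large = count_bert_params(30522, 1024, 24)      # assuming the 30522-token uncased vocabulary\n",
    "print(base_cased, large)\n",
    "```\n",
    "\n",
    "Under these assumptions the counts come out at roughly 108M for the cased base model and roughly 335M for BERT large, in line with the rounded 110M and 340M figures."
   ]
  },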
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Dataset({\n",
       "    features: ['Unnamed: 0', 'count', 'hate_speech', 'offensive_language', 'neither', 'class', 'tweet', 'tweet_cleaned'],\n",
       "    num_rows: 24783\n",
       "})"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "ds"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 19/19 [00:03<00:00,  5.26ba/s]\n",
      "100%|██████████| 2/2 [00:00<00:00,  4.72ba/s]\n",
      "100%|██████████| 5/5 [00:00<00:00,  5.65ba/s]\n"
     ]
    }
   ],
   "source": [
    "def tokenize_function(batch):\n",
    "    # With batched=True, `batch` maps each column name to a list of values.\n",
    "    # padding='max_length' pads every tweet to the tokenizer's maximum length (512 here).\n",
    "    return tokenizer(batch['tweet_cleaned'], padding='max_length', truncation=True)\n",
    "\n",
    "\n",
    "tokenized_dataset = dataset.map(tokenize_function, batched=True)\n",
    "\n",
    "train_dataset = tokenized_dataset['train']\n",
    "eval_dataset = tokenized_dataset['valid']\n",
    "test_dataset = tokenized_dataset['test']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Dataset({\n",
       "    features: ['class', 'tweet', 'tweet_cleaned', 'input_ids', 'token_type_ids', 'attention_mask'],\n",
       "    num_rows: 18587\n",
       "})"
      ]
     },
     "execution_count": 23,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "train_dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_set = train_dataset.remove_columns(['tweet', \"tweet_cleaned\"]).with_format('tensorflow')\n",
    "\n",
    "tf_eval_dataset = eval_dataset.remove_columns(['tweet', \"tweet_cleaned\"]).with_format('tensorflow')\n",
    "\n",
    "tf_test_dataset = test_dataset.remove_columns(['tweet', \"tweet_cleaned\"]).with_format('tensorflow')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_features = { x: train_set[x] for x in tokenizer.model_input_names  }\n",
    "\n",
    "train_set_for_final_model = tf.data.Dataset.from_tensor_slices((train_features, train_set['class'] ))\n",
    "\n",
    "train_set_for_final_model = train_set_for_final_model.shuffle(len(train_set)).batch(8)\n",
    "\n",
    "\n",
    "eval_features = {x: tf_eval_dataset[x] for x in tokenizer.model_input_names}\n",
    "val_set_for_final_model = tf.data.Dataset.from_tensor_slices((eval_features, tf_eval_dataset[\"class\"]))\n",
    "val_set_for_final_model = val_set_for_final_model.batch(8)\n",
    "\n",
    "test_features = {x: tf_test_dataset[x] for x in tokenizer.model_input_names}\n",
    "test_set_for_final_model = tf.data.Dataset.from_tensor_slices((test_features, tf_test_dataset[\"class\"]))\n",
    "test_set_for_final_model =test_set_for_final_model.batch(8)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = TFAutoModelForSequenceClassification.from_pretrained(\"bert-base-cased\", num_labels=3)\n",
    "# model = TFAutoModelForSequenceClassification.from_pretrained(\"/mnt/e0ccdbdb-22c3-4d9b-9413-fd976a2e99ae/M1/Code_Org/HF_Models/bert-base-uncased\", num_labels=3)\n",
    "\n",
    "\n",
    "model.compile(\n",
    "    optimizer=tf.keras.optimizers.Adam(learning_rate=5e-5),\n",
    "    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n",
    "    metrics=tf.metrics.SparseCategoricalAccuracy(),\n",
    ")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "history = model.fit(train_set_for_final_model, validation_data=val_set_for_final_model, epochs=3 )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.plot(history.history['sparse_categorical_accuracy'])\n",
    "plt.plot(history.history['val_sparse_categorical_accuracy'])\n",
    "plt.title('model sparse categorical accuracy')\n",
    "plt.ylabel('accuracy')\n",
    "plt.xlabel('epoch')\n",
    "plt.legend(['train', 'val'], loc='upper left')\n",
    "plt.show()\n",
    "\n",
    "\n",
    "plt.plot(history.history['loss'])\n",
    "plt.plot(history.history['val_loss'])\n",
    "plt.title('model loss')\n",
    "plt.ylabel('loss')\n",
    "plt.xlabel('epoch')\n",
    "plt.legend(['train', 'val'], loc='upper left')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_loss, test_acc = model.evaluate(test_set_for_final_model,verbose=2)\n",
    "print('\\nTest accuracy:', test_acc)"
   ]
  },
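  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Accuracy can hide poor performance on a rare class. The sketch below computes a confusion matrix and per-class recall with NumPy. It is a minimal illustration on made-up arrays: `y_true` and `y_pred` are placeholders for the test labels and the model's predicted classes, and `confusion_matrix` is a hypothetical helper.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def confusion_matrix(y_true, y_pred, num_classes):\n",
    "    # Rows are true classes, columns are predicted classes.\n",
    "    cm = np.zeros((num_classes, num_classes), dtype=int)\n",
    "    for t, p in zip(y_true, y_pred):\n",
    "        cm[t, p] += 1\n",
    "    return cm\n",
    "\n",
    "y_true = np.array([0, 1, 1, 2, 1, 0])  # made-up labels\n",
    "y_pred = np.array([1, 1, 1, 2, 1, 0])  # made-up predictions\n",
    "\n",
    "cm = confusion_matrix(y_true, y_pred, num_classes=3)\n",
    "recall = cm.diagonal() / cm.sum(axis=1)  # fraction of each true class predicted correctly\n",
    "print(cm)\n",
    "print(recall)\n",
    "```\n",
    "\n",
    "In this toy example class 0 has a recall of 0.5 even though five of the six predictions are correct."
   ]
  },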
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "predict_score_and_class_dict = \n",
    "{0: 'Hate Speech',\n",
    " 1: 'Offensive Language',\n",
    " 2: 'Neither'}\n",
    "\n",
    "preds = model(tokenizer([\"He is useless, I dont know why he came to our neighbourhood\", \"That guy sucks\", \"He is such a retard\"],return_tensors=\"tf\",padding=True,truncation=True))['logits']\n",
    "\n",
    "print(preds)\n",
    "\n",
    "class_preds = np.argmax(preds, axis=1)\n",
    "\n",
    "for pred in class_preds:\n",
    "  print(predict_score_and_class_dict[pred])"
   ]
  },
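  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The model returns raw logits, so the printed values are not probabilities. The sketch below turns logits into class probabilities with a softmax before taking the arg max. It is a minimal NumPy illustration: the logit values in `toy_logits` are made up, and `softmax` and `LABELS` are hypothetical names.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "LABELS = {0: 'Hate Speech', 1: 'Offensive Language', 2: 'Neither'}\n",
    "\n",
    "def softmax(logits):\n",
    "    # Subtract the row maximum for numerical stability.\n",
    "    shifted = logits - logits.max(axis=-1, keepdims=True)\n",
    "    exp = np.exp(shifted)\n",
    "    return exp / exp.sum(axis=-1, keepdims=True)\n",
    "\n",
    "toy_logits = np.array([[-1.2, 0.3, 2.5], [0.1, 3.0, -0.5]])  # made-up values\n",
    "probs = softmax(toy_logits)\n",
    "\n",
    "for row in probs:\n",
    "    # Print the most likely label together with its probability.\n",
    "    print(LABELS[int(row.argmax())], float(row.max()))\n",
    "```\n",
    "\n",
    "Softmax does not change which class has the highest score, so the predicted labels match those obtained from the logits directly."
   ]
  },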
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "predict_score_and_class_dict = {0: 'Hate Speech',\n",
    " 1: 'Offensive Language',\n",
    " 2: 'Neither'}\n",
    "preds = model(tokenizer([\"He dresses up like a begger thise days\"],return_tensors=\"tf\",padding=True,truncation=True))['logits']\n",
    "print(preds)\n",
    "class_preds = np.argmax(preds, axis=1)\n",
    "\n",
    "for pred in class_preds:\n",
    "  print(predict_score_and_class_dict[pred])"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.9.13 64-bit",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.13"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "f9f85f796d01129d0dd105a088854619f454435301f6ffec2fea96ecbd9be4ac"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
